#!/usr/bin/env python3
"""
Rockwell AIM-65 Audio Cassette Tape Decoder
===================================
Decodes audio WAV files recorded from a Rockwell AIM-65 microcomputer's
cassette tape interface. Supports both Object and Text data formats.

Encoding scheme (from Appendix F of the AIM-65 User's Guide):
  - FSK at 2400 / 1200 Hz
  - Each bit = 4 half-cycles
  - Logic "1" = 4 half-cycles of 2400 Hz
  - Logic "0" = 1 half-cycle of 2400 Hz + 3 of 1200 Hz
  - Bits are LSB first


Block format: 32 SYN (0x16) | '#' (0x23) | BLK | 79 data bytes | 2 checksum
              (32 + 1 + 1 + 79 + 2 = 115 bytes per block)

Object first block data: FILENAME(5) + CR(1) + object data(73)
Text first block data:   FILENAME(5) + text data(74)


Written by Claude Opus 4.6 Extended
Directed by J R Casey Bralla



Requires: Python 3.6+, scipy
Usage: python3 aim65_decode.py <wavfile>
"""

import sys
import wave
import array
import math


# Program version string and release date; both are printed in the banner
# emitted by print_results().
Version     = "1.03"
VersionDate = "March 15, 2026"

# ---------------------------------------------------------------------------
# Audio signal processing
# ---------------------------------------------------------------------------

def bandpass_filter(samples, sample_rate):
    """Isolate the FSK tones with a 4th-order Butterworth band-pass filter.

    The pass band is 800-3200 Hz, which brackets the 1200 Hz and 2400 Hz
    tones of the cassette encoding.

    Args:
        samples: Sequence of numeric audio samples.
        sample_rate: Sampling frequency of ``samples`` in Hz.

    Returns:
        A list with one filtered value per input sample.

    If scipy cannot be imported, an install hint is printed to stderr and
    the process exits with status 1.
    """
    passband_hz = [800, 3200]
    try:
        # Imported lazily so a missing scipy yields a friendly message
        # instead of a traceback at module import time.
        import scipy.signal as dsp
        sections = dsp.butter(4, passband_hz, btype='bandpass',
                              fs=sample_rate, output='sos')
        filtered = dsp.sosfilt(sections, samples)
    except ImportError:
        print("ERROR: scipy is required.  Install: pip install scipy",
              file=sys.stderr)
        sys.exit(1)
    return list(filtered)


def detect_crossings(samples):
    """Locate signal crossings using a 10% hysteresis band.

    The detector toggles between a "low" and a "high" state.  It starts
    low, switches to high at the first sample above +10% of the peak
    absolute amplitude, and back to low at the first sample below -10%.
    Every switch is recorded, so noise inside the +/-10% band cannot
    produce spurious crossings.

    Args:
        samples: Sequence of numeric samples (must support ``len``-style
            truthiness and repeated iteration, e.g. a list).

    Returns:
        A list of sample indices, in increasing order, at which the state
        switched.  Empty when ``samples`` is empty or entirely zero.
    """
    if not samples:
        return []
    peak = max(abs(s) for s in samples)
    if peak == 0:
        return []

    threshold = peak * 0.10
    is_high = False
    crossings = []
    for index, value in enumerate(samples):
        if is_high:
            # Only a swing below the negative threshold ends the high state.
            if value < -threshold:
                is_high = False
                crossings.append(index)
        elif value > threshold:
            is_high = True
            crossings.append(index)
    return crossings


# ---------------------------------------------------------------------------
# Segmentation
# ---------------------------------------------------------------------------

def segment_half_cycles(hc_durs, sample_rate):
    """Split a half-cycle duration stream into block-sized segments.

    Any duration longer than the gap threshold is treated as an
    inter-block gap: it is discarded and ends the current segment.  The
    threshold is ``sample_rate / 300`` samples (about 147 at 44.1 kHz,
    roughly eight times a 1200 Hz half-cycle), with a floor of 60.

    Args:
        hc_durs: Half-cycle durations, in samples.
        sample_rate: Sampling frequency in Hz.

    Returns:
        A list of segments (each a list of durations) containing at least
        2000 half-cycles; shorter runs are dropped as too small to hold a
        block.
    """
    if not hc_durs:
        return []

    # Minimum half-cycle count for a run to be considered a candidate block.
    min_block_hc = 2000
    gap_limit = max(int(sample_rate / 300), 60)

    # The last entry of `runs` is always the run currently being filled.
    runs = [[]]
    for length in hc_durs:
        if length > gap_limit:
            # Gap: close the current run (unless it is still empty).
            if runs[-1]:
                runs.append([])
        else:
            runs[-1].append(length)

    return [run for run in runs if len(run) >= min_block_hc]


# ---------------------------------------------------------------------------
# Bit and byte decoding
# ---------------------------------------------------------------------------

def decode_hc_to_bytes(hc_durs, sample_rate):
    """Turn half-cycle durations into byte values (LSB-first bit order).

    Half-cycles are grouped four at a time from the start of the stream,
    with no phase search.  A group holding two or more "long" half-cycles
    (at or above the midpoint between the 2400 Hz and 1200 Hz half-cycle
    lengths) decodes as bit 0; otherwise it decodes as bit 1.

    Args:
        hc_durs: Half-cycle durations, in samples.
        sample_rate: Sampling frequency in Hz.

    Returns:
        A list of integer byte values.  Trailing half-cycles or bits that
        do not fill a whole byte are ignored.  Streams shorter than 32
        half-cycles yield an empty list.
    """
    if len(hc_durs) < 32:
        return []

    short_len = sample_rate / 4800.0   # half-cycle of 2400 Hz
    long_len = sample_rate / 2400.0    # half-cycle of 1200 Hz
    cutoff = (short_len + long_len) / 2.0

    bits = []
    for k in range(len(hc_durs) // 4):
        quad = hc_durs[4 * k:4 * k + 4]
        long_count = len([d for d in quad if d >= cutoff])
        bits.append(1 if long_count < 2 else 0)

    decoded = []
    for base in range(0, len(bits) - 7, 8):
        octet = bits[base:base + 8]
        # The first bit received is the least significant one.
        decoded.append(sum(bit << shift for shift, bit in enumerate(octet)))
    return decoded


def find_block_starts_in_hc(hc_durs, sample_rate):
    """Locate carrier-to-data transitions in a half-cycle stream.

    Carrier between blocks is a stretch of short (2400 Hz) half-cycles,
    while block data contains long (1200 Hz) half-cycles for its 0-bits.
    A block start is therefore reported whenever a long half-cycle
    follows at least 20 consecutive short ones.

    Args:
        hc_durs: Half-cycle durations, in samples.
        sample_rate: Sampling frequency in Hz.

    Returns:
        A list of indices into ``hc_durs``.  Each index is one position
        before the long half-cycle that ended the carrier run, because a
        0-bit is encoded as one short half-cycle followed by three long
        ones and the bit therefore begins on the preceding short one.
    """
    cutoff = (sample_rate / 4800.0 + sample_rate / 2400.0) / 2.0

    # About five bits' worth of uninterrupted carrier.
    min_carrier_run = 20

    starts = []
    run_begin = 0   # index of the first half-cycle of the current short run
    for idx, length in enumerate(hc_durs):
        if length < cutoff:
            continue
        # A long half-cycle: the short run spans run_begin .. idx - 1.
        if idx - run_begin >= min_carrier_run:
            starts.append(max(0, idx - 1))
        run_begin = idx + 1
    return starts


def decode_block_at(hc_durs, start_pos, sample_rate):
    """Decode the block that begins near half-cycle index ``start_pos``.

    A window of one block (115 bytes * 8 bits * 4 half-cycles = 3680
    half-cycles) plus 64 half-cycles of slack is decoded at each of the
    four possible bit-phase offsets.  The offset whose byte stream has
    the longest first run of SYN (0x16) bytes wins; ties go to the
    lowest offset.

    Args:
        hc_durs: Half-cycle durations, in samples.
        start_pos: Index into ``hc_durs`` at which the window begins.
        sample_rate: Sampling frequency in Hz.

    Returns:
        The list of byte values decoded at the winning offset.  An empty
        list is returned when fewer than half a block of half-cycles is
        available or when no offset produced any SYN byte.
    """
    block_hc = 115 * 8 * 4
    slack = 64
    stop = min(start_pos + block_hc + slack, len(hc_durs))
    window = hc_durs[start_pos:stop]
    if len(window) < block_hc // 2:
        return []

    cutoff = (sample_rate / 4800.0 + sample_rate / 2400.0) / 2.0
    # Classify every half-cycle once; all four phases reuse the result.
    is_long = [d >= cutoff for d in window]

    winner = []
    winner_syn = 0
    for phase in range(4):
        # Two or more long half-cycles in a group of four means bit 0.
        bits = [0 if sum(is_long[p:p + 4]) >= 2 else 1
                for p in range(phase, len(window) - 3, 4)]

        # Pack complete groups of eight bits, first bit = LSB.
        candidate = [
            sum(bit << shift for shift, bit in enumerate(bits[b:b + 8]))
            for b in range(0, len(bits) - 7, 8)
        ]

        # Length of the first run of SYN bytes (non-SYN bytes ahead of
        # the run are skipped; the run ends at the first non-SYN after it).
        syn_run = 0
        for value in candidate:
            if value == 0x16:
                syn_run += 1
            elif syn_run:
                break

        if syn_run > winner_syn:
            winner_syn = syn_run
            winner = candidate

    return winner


# ---------------------------------------------------------------------------
# Block parsing
# ---------------------------------------------------------------------------

def parse_block(raw_bytes):
    """Parse the byte stream of one decoded block.

    Expected layout: SYN (0x16) run | '#' (0x23) | BLK | 79 data bytes |
    2 checksum bytes (little-endian).  The checksum is the 16-bit sum of
    BLK and the 79 data bytes.

    Args:
        raw_bytes: List of byte values, starting at the SYN run.

    Returns:
        None when the stream does not start with at least three SYN
        bytes, the SYN run is not followed by '#', or the block number is
        missing.  Otherwise a dict with:
          'syn_count'         - number of leading SYN bytes
          'blk_num'           - block number byte
          'data_80'           - the 79 data bytes, zero-padded if the
                                stream is short (key name is historical)
          'block_checksum_ok' - True/False, or None when the stored
                                checksum bytes are not present
    """
    total = len(raw_bytes)

    syn_run = 0
    for value in raw_bytes:
        if value != 0x16:
            break
        syn_run += 1

    if syn_run < 3:
        return None
    if syn_run >= total or raw_bytes[syn_run] != 0x23:
        return None
    if syn_run + 1 >= total:
        return None

    block_number = raw_bytes[syn_run + 1]
    first = syn_run + 2
    payload = raw_bytes[first:first + 79]
    shortfall = 79 - len(payload)
    if shortfall > 0:
        # Truncated block: pad so callers always get 79 bytes.
        payload.extend([0] * shortfall)

    expected = ((block_number & 0xFF) + sum(payload)) & 0xFFFF
    if first + 81 <= total:
        stored = raw_bytes[first + 79] | (raw_bytes[first + 80] << 8)
        checksum_ok = stored == expected
    else:
        checksum_ok = None

    return {
        'syn_count': syn_run,
        'blk_num': block_number,
        'data_80': payload,
        'block_checksum_ok': checksum_ok,
    }


# ---------------------------------------------------------------------------
# Object record parsing
# ---------------------------------------------------------------------------

def parse_object_records(data_bytes):
    """Extract object-code records from the concatenated block payload.

    Each record starts with ';' (0x3B); any other byte between records
    is skipped.  A data record is: count | addr_hi | addr_lo | count data
    bytes | checksum_hi | checksum_lo, optionally followed by CR (0x0D).
    Its checksum is the 16-bit sum of count, both address bytes, and the
    data bytes.  A record whose count is zero is the end-of-file record
    (record count and checksum, both big-endian) and stops parsing.

    Args:
        data_bytes: Sequence of byte values.

    Returns:
        A list of dicts.  Data records have 'type' == 'data' plus
        'address', 'byte_count', 'data', 'checksum', 'computed_checksum'
        and 'checksum_ok'.  The end record has 'type' == 'last' plus
        'record_count' and 'checksum'.  Parsing stops silently at the
        first truncated record.
    """
    record_mark = 0x3B
    carriage_return = 0x0D
    size = len(data_bytes)

    parsed = []
    cursor = 0
    while cursor < size:
        # Resynchronise on the next record marker.
        if data_bytes[cursor] != record_mark:
            cursor += 1
            continue
        cursor += 1
        if cursor >= size:
            break
        count = data_bytes[cursor]
        cursor += 1

        if count == 0x00:
            # End-of-file record; nothing after it is examined.
            if cursor + 4 <= size:
                total_records = (data_bytes[cursor] << 8) | data_bytes[cursor + 1]
                file_sum = (data_bytes[cursor + 2] << 8) | data_bytes[cursor + 3]
                parsed.append({'type': 'last', 'record_count': total_records,
                               'checksum': file_sum})
            break

        body_end = cursor + 2 + count
        if body_end + 2 > size:
            break

        hi = data_bytes[cursor]
        lo = data_bytes[cursor + 1]
        payload = list(data_bytes[cursor + 2:body_end])
        stored = (data_bytes[body_end] << 8) | data_bytes[body_end + 1]
        expected = (count + hi + lo + sum(payload)) & 0xFFFF

        cursor = body_end + 2
        if cursor < size and data_bytes[cursor] == carriage_return:
            cursor += 1

        parsed.append({
            'type': 'data',
            'address': (hi << 8) | lo,
            'byte_count': count,
            'data': payload,
            'checksum': stored,
            'computed_checksum': expected,
            'checksum_ok': stored == expected,
        })
    return parsed


# ---------------------------------------------------------------------------
# Main decoder
# ---------------------------------------------------------------------------

def decode_wav_file(filename):
    """Decode an AIM-65 audio cassette WAV file.

    Pipeline: read PCM samples, band-pass filter them, detect hysteresis
    crossings, convert crossings to half-cycle durations, split the
    durations at inter-block gaps, decode and parse each block, then
    assemble the payload as object records or text.

    Args:
        filename: Path of the WAV file.  8-, 16- and 32-bit PCM sample
            widths are handled; for multi-channel files only the first
            channel is used.

    Returns:
        None (after printing a message to stderr) when the file cannot
        be opened, has an unsupported sample width, contains too little
        signal, or yields no parsable block.  Otherwise a dict with:
          'file_name', 'file_type', 'is_object', 'blocks', 'all_data',
          'wav_info', 'segments', 'half_cycles_count', 'bytes_per_block',
          'data_length', and either 'records' (plus 'start_address' and
          'end_address' when data records exist) for object files or
          'text_data' for text files.
    """
    try:
        w = wave.open(filename, 'rb')
    except Exception as e:
        print(f"Error opening WAV file: {e}", file=sys.stderr)
        return None

    # The context manager closes the file even if reading the header or
    # the frames raises.
    with w:
        channels = w.getnchannels()
        sampwidth = w.getsampwidth()
        sample_rate = w.getframerate()
        nframes = w.getnframes()
        raw_frames = w.readframes(nframes)
    duration = nframes / sample_rate

    if sampwidth == 1:
        # 8-bit WAV samples are unsigned; centre them on zero.
        samples = [float(b - 128) for b in raw_frames]
    elif sampwidth in (2, 4):
        pcm = array.array('h' if sampwidth == 2 else 'i', raw_frames)
        # WAV PCM is little-endian, array.array is native-endian.
        if sys.byteorder == 'big':
            pcm.byteswap()
        samples = [float(s) for s in pcm]
    else:
        print(f"Error: unsupported sample width ({sampwidth})", file=sys.stderr)
        return None
    if channels > 1:
        # Frames are interleaved; keep the first channel only.
        samples = samples[::channels]

    # Filter entire signal
    filtered = bandpass_filter(samples, sample_rate)

    # Detect crossings and compute half-cycle durations
    crossings = detect_crossings(filtered)
    if len(crossings) < 200:
        print("Error: insufficient signal", file=sys.stderr)
        return None
    hc_durs = [later - earlier
               for earlier, later in zip(crossings, crossings[1:])]

    # Segment at inter-block gaps (keeps only large segments)
    segments = segment_half_cycles(hc_durs, sample_rate)
    if not segments:
        print("Error: no valid data segments found", file=sys.stderr)
        return None

    # Decode blocks from each segment.
    #
    # Real hardware recordings have silence gaps between blocks, producing
    # one segment per block.  Software-generated WAVs may use continuous
    # 2400 Hz carrier between blocks (no silence), putting all blocks into
    # a single large segment.
    #
    # Strategy:
    #   - Small segment (<= 5000 hc): treat as one block starting at 0.
    #   - Large segment (> 5000 hc): scan for carrier-to-data transitions
    #     and decode a block at each one; if none is found, fall back to
    #     treating the segment as a single block.
    ONE_BLOCK_MAX_HC = 5000  # well above 3680 hc for one block + margin

    blocks = []
    for seg in segments:
        if len(seg) <= ONE_BLOCK_MAX_HC:
            starts = [0]
        else:
            starts = find_block_starts_in_hc(seg, sample_rate) or [0]
        for start in starts:
            raw_bytes = decode_block_at(seg, start, sample_rate)
            block = parse_block(raw_bytes) if raw_bytes else None
            if block:
                blocks.append(block)

    if not blocks:
        print("Error: no valid blocks found", file=sys.stderr)
        return None

    # Sort blocks by block number
    blocks.sort(key=lambda b: b['blk_num'])

    # Determine file type from the lowest-numbered block: a CR right
    # after the 5-character file name marks an object file.
    first_80 = blocks[0]['data_80']
    file_name = ''.join(chr(b) if 32 <= b < 127 else '.' for b in first_80[0:5])
    format_byte = first_80[5] if len(first_80) > 5 else 0
    is_object = (format_byte == 0x0D)
    file_type = "Object" if is_object else "Text"

    # Block data is always 79 bytes
    use_bytes = 79

    # Concatenate the payloads.  Block 0 carries a header that is not
    # part of the data: filename(5) + CR(1) for object files,
    # filename(5) only for text files.
    header_len = 6 if is_object else 5
    all_data = []
    for block in blocks:
        payload = block['data_80'][:use_bytes]
        block['data'] = payload   # presentable per-block copy
        if block['blk_num'] == 0:
            all_data.extend(payload[header_len:])
        else:
            all_data.extend(payload)

    result = {
        'file_name': file_name,
        'file_type': file_type,
        'is_object': is_object,
        'blocks': blocks,
        'all_data': all_data,
        'wav_info': {
            'sample_rate': sample_rate, 'channels': channels,
            'sample_width': sampwidth, 'duration': duration,
        },
        'segments': len(segments),
        'half_cycles_count': len(hc_durs),
        'bytes_per_block': use_bytes,
    }

    if is_object:
        records = parse_object_records(all_data)
        result['records'] = records
        data_recs = [r for r in records if r['type'] == 'data']
        result['data_length'] = sum(r['byte_count'] for r in data_recs)
        if data_recs:
            result['start_address'] = min(r['address'] for r in data_recs)
            result['end_address'] = max(r['address'] + r['byte_count'] - 1
                                        for r in data_recs)
    else:
        # Text ends at the first 0x1A, the first 0x00, or the second of
        # two consecutive CRs, whichever comes first.
        text_data = []
        prev_was_cr = False
        for b in all_data:
            if b in (0x1A, 0x00):
                break
            if b == 0x0D and prev_was_cr:
                break
            prev_was_cr = (b == 0x0D)
            text_data.append(b)
        # Drop a single trailing CR.
        if text_data and text_data[-1] == 0x0D:
            text_data.pop()
        result['text_data'] = text_data
        result['data_length'] = len(text_data)

    return result


# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------

def print_hex_line(data, address, bpl=16):
    """Print one hex-dump row: address, hex bytes, and printable ASCII.

    Args:
        data: Byte values for this row.
        address: Address shown at the start of the row (4 hex digits).
        bpl: Bytes per line; the hex column is padded to ``bpl * 3 - 1``
            characters so the ASCII column lines up across rows.
    """
    column_width = bpl * 3 - 1
    hex_cells = [f'{value:02X}' for value in data]
    # Non-printable bytes are shown as '.'.
    text_cells = [chr(value) if 32 <= value < 127 else '.' for value in data]
    hex_column = ' '.join(hex_cells)
    ascii_column = ''.join(text_cells)
    print(f"    {address:04X}: {hex_column:<{column_width}s}  |{ascii_column}|")


def _print_object_section(result):
    """Print the record checksum summary and one line per object record."""
    records = result.get('records', [])
    data_recs = [r for r in records if r['type'] == 'data']
    total = len(data_recs)
    ok = len([r for r in data_recs if r['checksum_ok']])
    failed = total - ok
    if failed:
        print(f"  Record checksums: {ok}/{total} OK, {failed} FAIL")
    else:
        print(f"  Record checksums: all {total} OK")
    print()
    print("  4) Object Data:")
    print()
    for rec in records:
        kind = rec['type']
        if kind == 'data':
            status = "OK" if rec['checksum_ok'] else "FAIL"
            payload = rec['data']
            hex_part = ' '.join(f'{b:02X}' for b in payload)
            ascii_part = ''.join(chr(b) if 32 <= b < 127 else '.'
                                 for b in payload)
            print(f"    {rec['address']:04X}: {hex_part}  |{ascii_part}|"
                  f"  [{status}]")
        elif kind == 'last':
            print()
            print(f"    End-of-file ({rec['record_count']} records total)")


def _print_text_section(result):
    """Print the decoded text, one output line per CR-terminated line."""
    print("  4) Text Data:")
    print()
    pending = []
    for code in result.get('text_data', []):
        if code == 0x0D:
            # Carriage return ends the current line.
            print("    " + ''.join(pending))
            pending = []
        elif code == 0x0A:
            # Linefeeds are dropped (typically they follow a CR).
            continue
        elif 32 <= code < 127:
            pending.append(chr(code))
        elif code == 0x09:
            pending.append('\t')
        else:
            # Any other byte is shown as a bracketed hex escape.
            pending.append(f'[${code:02X}]')
    if pending:
        print("    " + ''.join(pending))


def print_results(result):
    """Print a human-readable report for a decoded tape file.

    Args:
        result: The dict produced by decode_wav_file().  Reads
            'file_name', 'file_type', 'is_object', 'data_length',
            'blocks', 'wav_info', 'segments', 'half_cycles_count', and,
            depending on the file type, 'records' / 'start_address' /
            'end_address' or 'text_data'.
    """
    heavy_rule = "=" * 65
    is_object = result['is_object']

    print()
    print(heavy_rule)
    print("  Rockwell AIM-65 Audio Cassette Tape Decoder")
    print("  Version " + Version + "  " + VersionDate)
    print(heavy_rule)
    print()
    print(f"  1) File Name:   \"{result['file_name']}\"")
    print(f"  2) File Type:   {result['file_type']}")
    length_line = f"  3) File Length: {result['data_length']} bytes"
    if is_object:
        first_addr = result.get('start_address', 0)
        last_addr = result.get('end_address', 0)
        length_line += f"  (${first_addr:04X} - ${last_addr:04X})"
    print(length_line)
    print()

    # Block summary
    block_list = result['blocks']
    print(f"  Blocks: {len(block_list)}")
    for entry in block_list:
        print(f"    Block ${entry['blk_num']:02X}:  SYN={entry['syn_count']}")
    print()

    if is_object:
        _print_object_section(result)
    else:
        _print_text_section(result)

    print()
    print("-" * 65)
    wav = result['wav_info']
    layout = 'stereo' if wav['channels'] > 1 else 'mono'
    print(f"  WAV: {wav['sample_rate']} Hz, "
          f"{wav['sample_width'] * 8}-bit, "
          f"{layout}, "
          f"{wav['duration']:.2f}s")
    print(f"  Segments: {result['segments']}  |  "
          f"Half-cycles: {result['half_cycles_count']}")
    print(heavy_rule)
    print()


def dump_raw_data(result, output_filename):
    """Write the decoded payload to ``output_filename`` in binary mode.

    Object files: the data bytes of every data record, written back to
    back in ascending address order (addresses themselves are not stored
    and gaps are not filled).  If there are no data records, a message is
    printed to stderr and no file is written.
    Text files: the bytes of 'text_data', unmodified.

    Args:
        result: The dict produced by decode_wav_file().
        output_filename: Path of the file to create or overwrite.
    """
    if not result['is_object']:
        text_bytes = result.get('text_data', [])
        with open(output_filename, 'wb') as out:
            out.write(bytes(text_bytes))
        print(f"  Wrote {len(text_bytes)} bytes to {output_filename}")
        return

    ordered = sorted(
        (r for r in result.get('records', []) if r['type'] == 'data'),
        key=lambda r: r['address'],
    )
    if not ordered:
        print("No data records to dump.", file=sys.stderr)
        return

    with open(output_filename, 'wb') as out:
        for rec in ordered:
            out.write(bytes(rec['data']))
    byte_total = sum(len(rec['data']) for rec in ordered)
    print(f"  Wrote {byte_total} bytes to {output_filename}")


def main():
    """Command-line entry point.

    Parses the arguments, decodes the given WAV file, prints the report,
    and optionally dumps the raw payload to the file named by -o/--output.
    Exits with status 1 when decoding fails.
    """
    import argparse

    description = ("AIM-65 Audio Cassette Tape Decoder\n\n"
                   "Decodes audio WAV files recorded from a Rockwell AIM-65\n"
                   "microcomputer's cassette tape interface.\n"
                   "Supports both Object and Text data file formats.\n"
                   "Requires: scipy (pip install scipy)")
    # RawDescriptionHelpFormatter keeps the line breaks of the description.
    cli = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    cli.add_argument('wavfile', help='Input WAV file to decode')
    cli.add_argument('-o', '--output', metavar='FILE',
                     help='Dump raw decoded data to FILE '
                          '(binary for object, text for text files)')
    options = cli.parse_args()

    decoded = decode_wav_file(options.wavfile)
    if decoded is None:
        print("\nDecoding failed.", file=sys.stderr)
        sys.exit(1)

    print_results(decoded)
    if options.output:
        dump_raw_data(decoded, options.output)

# Run the command-line interface only when executed as a script, so the
# module can be imported without side effects.
if __name__ == '__main__':
    main()
